{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# t-SNE Animations\n",
    "\n",
    "*openTSNE* includes a callback system, which can be triggered every *n* iterations and can also be used to control the optimization and decide when to stop.\n",
    "\n",
    "In this notebook, we'll look at an example and use callbacks to generate an animation of the optimization. In practice, this serves no real purpose other than being fun to look at."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openTSNE import TSNE\n",
    "\n",
    "from examples import utils\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "with gzip.open(\"data/macosko_2015.pkl.gz\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "x = data[\"pca_50\"]\n",
    "y = data[\"CellType1\"].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data set contains 44808 samples with 50 features\n"
     ]
    }
   ],
   "source": [
    "print(\"Data set contains %d samples with %d features\" % x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We pass a callback that takes the current embedding, makes a copy (this is important because the embedding is modified in place during optimization), and appends it to a list. We can also specify how often callbacks are called. In this instance, we'll call ours at every iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = []\n",
    "\n",
    "tsne = TSNE(\n",
    "    perplexity=50, metric=\"cosine\",\n",
    "    # Let's use the fast approximation methods\n",
    "    neighbors=\"approx\", negative_gradient_method=\"fft\",\n",
    "    # Append a copy of each embedding to the list defined above. Without the copy,\n",
    "    # every entry would reference the same array, which is modified in place\n",
    "    callbacks=lambda it, err, emb: embeddings.append(np.array(emb)),\n",
    "    # Run the callback on every iteration\n",
    "    callbacks_every_iters=1,\n",
    "    # -2 will use all but one core so I can look at cute cat pictures while this computes\n",
    "    n_jobs=-2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16min 13s, sys: 2.1 s, total: 16min 15s\n",
      "Wall time: 2min 45s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TSNEEmbedding([[37.18444416, 18.72773713],\n",
       "               [37.12496354, 18.80951715],\n",
       "               [37.16077576, 18.89134288],\n",
       "               ...,\n",
       "               [40.50806396, 20.87112375],\n",
       "               [ 0.97918374, 13.68780657],\n",
       "               [-6.28309915,  5.94863532]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time tsne.fit(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the embedding from every iteration in our list, we can create the animation. We do this here using matplotlib, which is relatively straightforward. Generating the animation can take a long time, so we save it as a GIF and can come back to it whenever we want without having to wait again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6min 47s, sys: 2.41 s, total: 6min 50s\n",
      "Wall time: 8min 18s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "fig = plt.figure(figsize=(7, 7))\n",
    "ax = fig.add_axes([0, 0, 1, 1])\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "colors = list(map(utils.MACOSKO_COLORS.get, y))\n",
    "pathcol = ax.scatter(embeddings[0][:, 0], embeddings[0][:, 1], c=colors, s=1, rasterized=True)\n",
    "\n",
    "def update(embedding, ax, pathcol):\n",
    "    # Update point positions\n",
    "    pathcol.set_offsets(embedding)\n",
    "    \n",
    "    # Adjust x/y limits so all the points are visible\n",
    "    ax.set_xlim(np.min(embedding[:, 0]), np.max(embedding[:, 0]))\n",
    "    ax.set_ylim(np.min(embedding[:, 1]), np.max(embedding[:, 1]))\n",
    "    \n",
    "    return [pathcol]\n",
    "\n",
    "anim = animation.FuncAnimation(\n",
    "    fig, update, fargs=(ax, pathcol), interval=20,\n",
    "    frames=embeddings, blit=True,\n",
    ")\n",
    "\n",
    "anim.save(\"macosko.gif\", dpi=60, writer=\"imagemagick\")\n",
    "plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
